#########
## Nature Human Behaviour 
## Leading Countries in Global Science Increasingly Receive More Citations than Other Countries Despite Doing Similar Research.
## https://doi.org/10.1038/s41562-022-01351-5
## Harvard Dataverse (Code and Metadata): https://doi.org/10.7910/DVN/WCOINR 
## Step 0A
## Data: Data_20210905
#########

# At a terminal, run: 
# '''Note: '$1' is the discipline ID that you pass in.'''
# ml python/2.7.13 py-scipy/1.1.0_py27 py-scikit-learn/0.19.1_py27 py-numpy/1.14.3_py27 py-pandas/0.23.0_py27
# srun python Step_X0A_Python2_Athena_MAG_Field_Metadata.py "$1"

# Installing PyAthena: 
# https://pypi.org/project/PyAthena/
# PYTHONUSERBASE=$GROUP_HOME/python pip install --user PyAthena
# export PYTHONPATH=$GROUP_HOME/python/lib/python2.7/site-packages:$PYTHONPATH


#############################
### Input
#############################
# Sociology - 144024400
# Atomic Physics - 184779094
# Information Dissemination - 2779494480
# Internal medicine - 126322002

import sys

# The MAG field-of-study ID (e.g. 144024400 for Sociology) is the first
# command-line argument.  It is substituted verbatim for the FIELDIDHERE
# placeholder in the Athena queries below (both in table names and as an
# unquoted numeric literal), so fail fast with a clear message when it is
# missing or is not a plain integer, instead of surfacing a bare IndexError
# here or an opaque SQL error later.
if len(sys.argv) < 2:
    sys.exit("usage: python {} <field_of_study_id>".format(sys.argv[0]))
discipline = str(sys.argv[1])
if not discipline.isdigit():
    sys.exit("field-of-study ID must be a non-negative integer, got {!r}".format(discipline))

#############################
### Modules
#############################
from pyathena import connect
import pandas as pd
import json
import numpy as np
import nltk
import string
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk import everygrams
import gc 
from nltk.stem import *
from nltk.tokenize import RegexpTokenizer
from sklearn.feature_extraction.text import CountVectorizer 
from nltk.stem.porter import *
from nltk.stem.snowball import SnowballStemmer
import sys
from collections import Counter
import random
import bz2 
import pickle
import cPickle
import itertools
import multiprocessing as mp
from os import path

# Python 2 only: reload(sys) makes sys.setdefaultencoding available again so
# that implicit str/unicode conversions use UTF-8 instead of ASCII.  Neither
# the reload builtin nor sys.setdefaultencoding exists in Python 3.
reload(sys)
sys.setdefaultencoding("utf-8")

#############################
### Time Start
#############################
# Wall-clock start (seconds since the epoch); used for the duration report
# printed after the output file has been written.
import time 
start_time = time.time()


#############################
### Functions
#############################

def reduce_mem_usage(df):
    """
    Downcast the numeric columns of a dataframe to the narrowest dtype whose
    range contains the column's values, to reduce memory usage.

    Only signed-integer and floating-point NumPy columns are converted.
    Every other column (object, bool, unsigned integer, datetime, extension
    dtypes, ...) is left untouched; previously such columns fell through to
    the float branch and were cast to float16 or raised on comparison.

    Parameters
    ----------
    df : pandas.DataFrame
        Modified in place, one column at a time.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.

    Notes
    -----
    * The range test is strict on both ends, so a value equal to a dtype's
      minimum or maximum selects the next wider dtype.
    * Float columns that fit are stored as float16/float32, which rounds the
      values (float16 keeps roughly three significant decimal digits).
    * Columns whose min/max cannot be compared (e.g. empty or all-NaN) keep
      their dtype for integers and become float64 for floats.
    """
    # Candidate dtypes, narrowest first; the first one that fits wins.
    int_candidates = (np.int8, np.int16, np.int32, np.int64)
    float_candidates = (np.float16, np.float32)

    for col in df.columns:
        col_type = df[col].dtype
        # Extension dtypes (categorical, tz-aware datetimes, ...) are skipped.
        if not isinstance(col_type, np.dtype):
            continue
        if col_type.kind == 'i':
            candidates, limits_of = int_candidates, np.iinfo
        elif col_type.kind == 'f':
            candidates, limits_of = float_candidates, np.finfo
        else:
            continue

        c_min = df[col].min()
        c_max = df[col].max()
        for candidate in candidates:
            limits = limits_of(candidate)
            if c_min > limits.min and c_max < limits.max:
                df[col] = df[col].astype(candidate)
                break
        else:
            # Nothing fitted: floats fall back to float64, ints stay as is.
            if col_type.kind == 'f':
                df[col] = df[col].astype(np.float64)
    return df

def divide_chunks(l, n):
    """Lazily yield consecutive slices of `l`, each holding at most `n` items."""
    # The last slice is shorter whenever len(l) is not a multiple of n.
    starts = range(0, len(l), n)
    for start in starts:
        stop = start + n
        yield l[start:stop]

import itertools
def split_dict(x, chunks):
    """
    Deal the items of dict `x` round-robin into `chunks` new dicts.

    Returns a list of `chunks` dicts; the i-th item of `x` (in iteration
    order) lands in dict number i modulo `chunks`.
    """
    buckets = [{} for _ in range(chunks)]
    slot = itertools.cycle(range(chunks))
    for key, value in x.items():
        target = buckets[next(slot)]
        target[key] = value
    return buckets

def invertedindex_to_string(x):
    """
    Rebuild the token sequence of a text from its JSON inverted index.

    Parameters
    ----------
    x : str
        JSON document with an "InvertedIndex" object that maps each token to
        the list of positions at which it occurs.

    Returns
    -------
    list of str
        The tokens, stripped and lower-cased, ordered by position.  If two
        tokens claim the same position, the one iterated last wins.

    Raises
    ------
    KeyError
        If the document has no "InvertedIndex" entry.
    """
    inverted_index = json.loads(x)["InvertedIndex"]
    by_position = {}
    for token, positions in inverted_index.items():
        # Python 2's json returns unicode, which is encoded to a UTF-8 byte
        # string exactly as before.  On Python 3 the token is already a str;
        # calling str() on encoded bytes there would produce "b'...'".
        if not isinstance(token, str):
            token = token.encode('utf-8')
        token = token.strip().lower()
        for position in positions:
            by_position[position] = token
    # Emit in position order; plain dict order does not guarantee it.
    return [by_position[position] for position in sorted(by_position)]

def fieldIDString(x, input_discipline):
    """Return query template `x` with every FIELDIDHERE placeholder replaced by `input_discipline`."""
    placeholder = "FIELDIDHERE"
    filled = x.replace(placeholder, input_discipline)
    return filled

###############################
### Connections
###############################

# Open ONE Athena connection and derive the cursor from it.  `conn` is handed
# to pandas.read_sql for queries whose result is loaded into a dataframe;
# `cursor` runs the CREATE/DROP TABLE statements.  Previously two separate
# connections were opened with the same settings duplicated.
# NOTE(review): the credentials are literal strings in the source; consider
# supplying them through the environment instead of committing them.
conn = connect(aws_access_key_id='REDACTED',
               aws_secret_access_key='REDACTED',
               s3_staging_dir='REDACTED',
               region_name='REDACTED')

cursor = conn.cursor()


###############################
### Paper ID -- Year and Title
###############################

# Probe: returns one row when a table named temp_paper_<field id> is listed
# in INFORMATION_SCHEMA.TABLES, and no rows otherwise.
check_paper_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_paper_FIELDIDHERE');
'''

# Cleanup statement for the staging table; executed in the cleanup section at
# the end of the script, and only when both output files exist.
drop_paper_query = '''
drop table if exists mag_staging.temp_paper_FIELDIDHERE;
'''

# CTAS: id, title, year and document type of every paper tagged with the
# requested field of study.
paper_query = '''
create table mag_staging.temp_paper_FIELDIDHERE 
WITH (
    format = 'TEXTFILE',
    field_delimiter = '\t'
  )
as 
SELECT Paper_Table.paperid as PaperID,
       Paper_Table.papertitle as Title,
       Paper_Table.full_year as Year,
       Paper_Table.doctype as DocType
FROM (
  Select Paper_Discipline_Table.paperid, 
         Discipline_Table.displayname
  FROM "mag_data"."paperfieldsofstudy" as Paper_Discipline_Table
  JOIN "mag_data"."fieldsofstudy" as Discipline_Table
  ON Discipline_Table.fieldofstudyid = Paper_Discipline_Table.fieldofstudyid
  WHERE Discipline_Table.fieldofstudyid = FIELDIDHERE) as Paper_Discipline_Key_Table
JOIN "mag_data"."papers" as Paper_Table
ON Paper_Table.paperid = Paper_Discipline_Key_Table.paperid;
'''

# Create the staging table only when the probe finds no existing one, so a
# re-run for the same field reuses the earlier table.
paper_table_probe = pd.read_sql(fieldIDString(check_paper_query, discipline), conn)
if paper_table_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(paper_query, discipline))
################################
## Paper ID and Title
################################
read_paper_query = '''
SELECT * FROM mag_staging.temp_paper_FIELDIDHERE
'''
# Load the paper staging table and keep only the id, year and title columns.
# The columns are addressed by lower-case names here although the CTAS
# aliases are mixed-case (presumably the engine lower-cases them -- confirm).
df_year = pd.read_sql(fieldIDString(read_paper_query,discipline), conn)
df_year = df_year[["paperid","year","title"]]

# df_year stays alive on purpose: it is converted to dict_year and deleted
# just before the output file is written.  The intermediate paperid/year
# dict that used to be built here was never read before being overwritten,
# and only held a large object in memory through all the queries that follow.
gc.collect()

#################################
#### Author Affiliations
#################################
# Probe: one row when temp_author_affiliation_<field id> already exists.
check_author_affiliations_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_author_affiliation_FIELDIDHERE');
'''

# Cleanup statement, executed in the cleanup section at the end of the script.
drop_author_affiliations_query = '''
drop table if exists mag_staging.temp_author_affiliation_FIELDIDHERE;
'''

# CTAS: one row per (paper, affiliation) for the papers of the field, with
# NumAuthors = number of paper-author-affiliation rows in that group.
author_affiliations_query = '''
create table mag_staging.temp_author_affiliation_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
    field_delimiter = '\t'
  )
as 
SELECT  Paper_Discipline_Key_Table.paperid as PaperID,
        Affiliations_Table.affiliationid as AffiliationID,
        Affiliations_Table.normalizedname as Normalized_Name,
        Affiliations_Table.gridid as GridID,
        count(*) as NumAuthors
FROM (
  Select Paper_Discipline_Table.paperid 
  FROM "mag_data"."paperfieldsofstudy" as Paper_Discipline_Table
  WHERE Paper_Discipline_Table.fieldofstudyid = FIELDIDHERE) as Paper_Discipline_Key_Table
JOIN "mag_data"."paperauthoraffiliations" as Paper_Affiliations_Table
ON Paper_Affiliations_Table.paperid = Paper_Discipline_Key_Table.paperid 
JOIN "mag_data"."affiliations" as Affiliations_Table
ON Paper_Affiliations_Table.paperid = Paper_Discipline_Key_Table.paperid
AND Affiliations_Table.affiliationid = Paper_Affiliations_Table.affiliationid
GROUP BY Paper_Discipline_Key_Table.paperid, Affiliations_Table.affiliationid, Affiliations_Table.normalizedname, Affiliations_Table.gridid
'''

# Create the table only when the probe finds no existing one.
author_affiliation_probe = pd.read_sql(fieldIDString(check_author_affiliations_query, discipline), conn)
if author_affiliation_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(author_affiliations_query, discipline))

##################################
#### Citations Affiliations
##################################
# Probe: one row when the per-field citation edgelist table already exists.
check_citation_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_citation_edgelist_what_paper_is_citing_FIELDIDHERE');
'''

# Cleanup statement, executed in the cleanup section at the end of the script.
drop_citation_query = '''
drop table if exists mag_staging.temp_citation_edgelist_what_paper_is_citing_FIELDIDHERE;
'''

# CTAS: for every reference made by a paper of the field, one row per
# (citing paper, cited paper, GRID id of a cited-paper affiliation), with
# NumAuthors = number of paper-author-affiliation rows in that group.
citation_query = '''
create table mag_staging.temp_citation_edgelist_what_paper_is_citing_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
    field_delimiter = '\t'
  )
as 
SELECT Paper_Reference_Discipline_Key_Table.paperid as PaperID,
       Paper_Reference_Discipline_Key_Table.paperreferenceid as CitationID,
        Affiliations_Table.gridid as CitationGridID,
        count(*) as NumAuthors
FROM (
  Select Paper_Discipline_Table.paperid,  
         Citations_Table.paperreferenceid
  FROM "mag_data"."paperfieldsofstudy" as Paper_Discipline_Table
  JOIN "mag_data"."paperreferences" as Citations_Table
  ON Paper_Discipline_Table.paperid = Citations_Table.paperid
  WHERE Paper_Discipline_Table.fieldofstudyid = FIELDIDHERE
      ) as Paper_Reference_Discipline_Key_Table
JOIN "mag_data"."paperauthoraffiliations" as Paper_Affiliations_Table
ON Paper_Affiliations_Table.paperid = Paper_Reference_Discipline_Key_Table.paperreferenceid 
JOIN "mag_data"."affiliations" as Affiliations_Table
ON Paper_Affiliations_Table.paperid = Paper_Reference_Discipline_Key_Table.paperreferenceid
AND Affiliations_Table.affiliationid = Paper_Affiliations_Table.affiliationid
GROUP BY Paper_Reference_Discipline_Key_Table.paperid, Paper_Reference_Discipline_Key_Table.paperreferenceid, Affiliations_Table.gridid
'''

# Create the table only when the probe finds no existing one.
citation_table_probe = pd.read_sql(fieldIDString(check_citation_query, discipline), conn)
if citation_table_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(citation_query, discipline))

##################################
#### Citations Year
##################################
# Probe: one row when the citing-year / cited-year table already exists.
check_citation_year_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_citation_edgelist_year_FIELDIDHERE');
'''

# Cleanup statement, executed in the cleanup section at the end of the script.
drop_citation_year_query = '''
drop table if exists mag_staging.temp_citation_edgelist_year_FIELDIDHERE;
'''

# CTAS: number of references made by papers of the field, grouped by the
# publication year of the citing paper and of the cited paper.
citation_query_year = '''
create table mag_staging.temp_citation_edgelist_year_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
  	field_delimiter = '\t'
  )
as 
SELECT Papers_Citing.full_year as CitingYear,
       Papers_Cited.full_year as CitedYear,
       count(*) as NumCitation
FROM (
  Select Paper_Discipline_Table.paperid,  
         Citations_Table.paperreferenceid
  FROM "mag_data"."paperfieldsofstudy" as Paper_Discipline_Table
  JOIN "mag_data"."paperreferences" as Citations_Table
  ON Paper_Discipline_Table.paperid = Citations_Table.paperid
  WHERE Paper_Discipline_Table.fieldofstudyid = FIELDIDHERE
      ) as Paper_Reference_Discipline_Key_Table
JOIN "mag_data"."papers" as Papers_Citing
ON Papers_Citing.paperid = Paper_Reference_Discipline_Key_Table.paperid 
JOIN "mag_data"."papers" as Papers_Cited
ON Papers_Cited.paperid = Paper_Reference_Discipline_Key_Table.paperreferenceid
GROUP BY Papers_Citing.full_year, Papers_Cited.full_year
'''

# Create the table only when the probe finds no existing one.
citation_year_probe = pd.read_sql(fieldIDString(check_citation_year_query, discipline), conn)
if citation_year_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(citation_query_year, discipline))

##################################
#### Migration Affiliations
##################################
# Probe: one row when the per-field migration table already exists.
check_migration_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_migration_table_FIELDIDHERE');
'''

# Cleanup statement, executed in the cleanup section at the end of the script.
drop_migration_query = '''
drop table if exists mag_staging.temp_migration_table_FIELDIDHERE;
'''

# CTAS: (author, year, GRID id) for every authorship of a paper in the
# per-field paper table whose affiliation has a non-empty GRID id.
migration_query = '''
create table mag_staging.temp_migration_table_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
  	field_delimiter = '\t'
  )
as
SELECT b.authorid, a.Year, c.gridid
FROM "mag_staging"."temp_paper_FIELDIDHERE" a
JOIN "mag_data"."paperauthoraffiliations" b
ON a.paperid = b.paperid
JOIN "mag_data"."affiliations" c
ON b.affiliationid = c.affiliationid
WHERE c.gridid!=''
ORDER by b.authorid, a.Year
'''

# Create the table only when the probe finds no existing one.
migration_table_probe = pd.read_sql(fieldIDString(check_migration_query, discipline), conn)
if migration_table_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(migration_query, discipline))


#####################################
##### Final Edgelists
#####################################
# Probe: one row when the institution-level citation edgelist already exists.
check_citation_GRID_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_citation_new_full_edgelist_FIELDIDHERE');
'''

# Drop statement for that table.
# NOTE(review): nothing in this file executes it -- confirm whether the
# table is meant to survive the run.
drop_citation_GRID_query = '''
drop table if exists mag_staging.temp_citation_new_full_edgelist_FIELDIDHERE;
'''

# CTAS: fractional citation counts between institutions.  Sub-query `a`
# gives each citing paper's share of authors per GRID id, sub-query `b` the
# share of authors per GRID id of each cited paper; the product of the two
# shares is summed per (citing GRID id, cited GRID id, citing year).
# Rows with an empty GRID id are excluded before the shares are computed.
citation_GRID_query = '''
create table mag_staging.temp_citation_new_full_edgelist_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
    field_delimiter = '\t'
  )
as 
SELECT a.paper_gridid,
         b.citationgridid,
         a.Year,
         SUM(percent_citingauthors * percent_citedauthors) AS Citations
FROM 
    (SELECT d.paperid,
         c.Year,
         d.gridid AS paper_gridid,
         cast(d.numauthors AS decimal(6,
         2)) / sum(d.numauthors) over(partition by d.paperid) AS percent_citingauthors
    FROM "mag_staging"."temp_author_affiliation_FIELDIDHERE" d
    JOIN "mag_staging"."temp_paper_FIELDIDHERE" c
        ON d.paperid = c.paperid
    WHERE d.gridid!='') a
JOIN 
    (SELECT paperid,
         citationid,
         citationgridid,
         cast(numauthors AS decimal(6,
         2)) / sum(numauthors) over(partition by paperid,
        citationid) AS percent_citedauthors
    FROM "mag_staging"."temp_citation_edgelist_what_paper_is_citing_FIELDIDHERE"
    WHERE citationgridid!='') b
    ON a.paperid = b.paperid
GROUP BY  paper_gridid, citationgridid, Year 
'''

# Create the table only when the probe finds no existing one.
citation_GRID_probe = pd.read_sql(fieldIDString(check_citation_GRID_query, discipline), conn)
if citation_GRID_probe["returnTrue"].empty:
    cursor.execute(fieldIDString(citation_GRID_query, discipline))

# Pull the edgelist, shrink its numeric columns, and keep only the plain-dict
# form (column -> {row index -> value}) that is pickled later.
citation_edges_frame = pd.read_sql(
    fieldIDString('''Select * From "mag_staging"."temp_citation_new_full_edgelist_FIELDIDHERE" ''', discipline),
    conn)
citation_edges_frame = reduce_mem_usage(citation_edges_frame)

dict_citation_edgelist = citation_edges_frame.to_dict()
del citation_edges_frame
gc.collect()

# Read back the citing-year / cited-year counts and keep them as a dict.
citation_by_year_edgelist = '''
SELECT *
FROM "mag_staging"."temp_citation_edgelist_year_FIELDIDHERE"
'''
citation_year_frame = pd.read_sql(fieldIDString(citation_by_year_edgelist, discipline), conn)

dict_citation_year_edgelist = citation_year_frame.to_dict()
# Release the dataframe right away; only the dict is needed from here on.
del citation_year_frame
gc.collect()

# Co-authorship edgelist between institutions: for every pair of distinct,
# non-empty GRID ids that appear on the same paper, the product of their
# author counts is summed per (GRID id i, GRID id j, year).  Both orderings
# of a pair are produced because the table is joined with itself.
coauthor_edgelist = '''
SELECT  a.gridid as gridid_i, b.gridid as gridid_j, c.year, sum((a.numauthors * b.numauthors)) as sumnumauthors 
FROM "mag_staging"."temp_author_affiliation_FIELDIDHERE" a
JOIN "mag_staging"."temp_paper_FIELDIDHERE" c
ON a.paperid = c.paperid
JOIN "mag_staging"."temp_author_affiliation_FIELDIDHERE" b
ON a.paperid = b.paperid
WHERE a.gridid!=b.gridid
AND a.gridid!=''
AND b.gridid!=''
GROUP BY a.gridid, b.gridid, c.year
'''
coauthor_frame = pd.read_sql(fieldIDString(coauthor_edgelist, discipline), conn)

dict_coauthor_edgelist = coauthor_frame.to_dict()
# Release the dataframe right away; only the dict is needed from here on.
del coauthor_frame
gc.collect()

# Mobility edgelist between institutions.  Both sub-queries pair rows of the
# migration table that share an author, where the "arrival" row is in a later
# year and at a different GRID id than the "departure" row, and count the
# pairs per (arrival year, arrival GRID id, departure GRID id):
#   five_year  - only pairs at most five years apart
#   all_career - all pairs
# NOTE(review): the two are combined with an inner join, so a triple that
# has no pair within five years is absent from the result -- confirm that
# this is intended.
mobility_edgelist = '''
SELECT all_career.year, all_career.arrivalGRID, all_career.departureGRID,five_year.nummoves_five_year,all_career.nummoves_all_career 
FROM (
  SELECT arrival.year, arrival.gridid as arrivalGRID , departure.gridid as departureGRID, count(*) as nummoves_five_year
FROM "mag_staging"."temp_migration_table_FIELDIDHERE" arrival
JOIN "mag_staging"."temp_migration_table_FIELDIDHERE" departure
ON arrival.authorid=departure.authorid
WHERE arrival.year>departure.year
AND arrival.gridid!=departure.gridid
AND (arrival.year-departure.year)<=5
GROUP BY arrival.gridid, departure.gridid, arrival.year) five_year
JOIN (
  SELECT arrival.year, arrival.gridid as arrivalGRID, departure.gridid as departureGRID, count(*) as nummoves_all_career
FROM "mag_staging"."temp_migration_table_FIELDIDHERE" arrival
JOIN "mag_staging"."temp_migration_table_FIELDIDHERE" departure
ON arrival.authorid=departure.authorid
WHERE arrival.year>departure.year
AND arrival.gridid!=departure.gridid
GROUP BY arrival.gridid, departure.gridid, arrival.year
  ) all_career
  ON five_year.year=all_career.year
  AND five_year.arrivalGRID = all_career.arrivalGRID
  AND five_year.departureGRID = all_career.departureGRID
'''
mobility_frame = pd.read_sql(fieldIDString(mobility_edgelist, discipline), conn)
dict_mobility_edgelist = mobility_frame.to_dict()

# Release the dataframe right away; only the dict is needed from here on.
del mobility_frame
gc.collect()


#######
# Paper -> GRID id labels, taken from the per-field author-affiliation table.
read_affiliation_query = '''
SELECT * FROM mag_staging.temp_author_affiliation_FIELDIDHERE
'''
affiliation_rows = pd.read_sql(fieldIDString(read_affiliation_query, discipline), conn)
paper_grid_pairs = affiliation_rows[["paperid", "gridid"]]

# Drop the full table before building the dict so the unused columns are
# released first.
del affiliation_rows
gc.collect()

dict_grid_labels = paper_grid_pairs.to_dict()
del paper_grid_pairs
gc.collect()

# Paper id / year / title frame loaded earlier in the script: convert it to
# a dict and release the frame.
dict_year = df_year.to_dict()
del df_year
gc.collect()


###########
##### Print out objects to disk via Pickle
###########
Field_MetaData_Filename = "OUTPUT_Python_MAG_Field_MetaData_Dict_"+str(discipline)+".pbz2"

# All six dicts go into one bz2-compressed file as consecutive pickles, so a
# reader has to call load() six times in this same order.
with bz2.BZ2File(Field_MetaData_Filename, 'w') as f:
    for payload in (dict_grid_labels,
                    dict_year,
                    dict_coauthor_edgelist,
                    dict_mobility_edgelist,
                    dict_citation_edgelist,
                    dict_citation_year_edgelist):
        cPickle.dump(payload, f)
# The with-statement has already closed the file at this point.

# Drop every reference (including the loop variable) so the memory can be
# reclaimed.
del payload
del dict_grid_labels, dict_year, dict_coauthor_edgelist
del dict_mobility_edgelist, dict_citation_edgelist, dict_citation_year_edgelist
gc.collect()

# Report the total wall-clock run time.  start_time is a time.time() float,
# so take the end time from the same clock; the previous datetime.now() call
# raised NameError (datetime is never imported), which aborted the script
# before the temporary-table cleanup below could run.
from datetime import timedelta
end_time = time.time()
print('Duration: {}'.format(timedelta(seconds=end_time - start_time)))

###########
#### Remove Temporary Tables
###########
# Drop the intermediate staging tables, in the same order as before.
# NOTE(review): temp_citation_new_full_edgelist_<field id> is not dropped
# here (its drop statement is defined but never executed in this file) --
# confirm whether that is intentional.
for drop_statement in (drop_author_affiliations_query,
                       drop_citation_query,
                       drop_citation_year_query,
                       drop_migration_query):
    cursor.execute(fieldIDString(drop_statement, discipline))

######## Drop Paper Table IF All Corpora and Metadata Exist As Well
Field_RAKE_Corpus_Filename = "OUTPUT_Python_MAG_Field_Corpus_RAKE_"+str(discipline)+".pbz2"
Field_MetaData_Filename = "OUTPUT_Python_MAG_Field_MetaData_Dict_"+str(discipline)+".pbz2"


# The paper table is dropped only once both the RAKE corpus file and the
# metadata file are present in the working directory.
all_outputs_present = path.exists(Field_RAKE_Corpus_Filename) and path.exists(Field_MetaData_Filename)
if all_outputs_present:
    cursor.execute(fieldIDString(drop_paper_query, discipline))


## Test to read in the cPickle object
# f = bz2.BZ2File(Field_MetaData_Filename, 'rb')
# dict_grid_labels = cPickle.load(f)
# dict_year = cPickle.load(f)
# dict_coauthor_edgelist = cPickle.load(f)
# dict_mobility_edgelist = cPickle.load(f)
# dict_citation_edgelist = cPickle.load(f)
# dict_citation_year_edgelist = cPickle.load(f)

